import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, Adamax
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_loader import *

import numpy as np
import time
from torch.optim import lr_scheduler
from torchvision import models
import json

from MoE_RIM import MoE_RIM

def print_model_info(model):
    """Print a human-readable summary of ``model`` to stdout.

    Prints the full module repr, the total parameter count, and then, for each
    direct child (``named_children``), its repr, the shape of its ``weight``
    when it has a non-None one, and, for ``nn.ModuleList`` children, every
    element labelled "Expert i".

    Args:
        model: an ``nn.Module``.
    """
    print("\n=== 模型详细信息 ===")  # "Detailed model information"
    print(f"模型结构：\n{model}")  # "Model structure"
    # "Number of parameters" (all parameters, trainable or not)
    print(f"\n参数数量：{sum(p.numel() for p in model.parameters()):,}")
    print("\n各层详情：")  # "Per-layer details"
    for name, module in model.named_children():
        print(f"{name}: {module}")
        # Some modules register ``weight`` as None (e.g. LayerNorm without
        # affine parameters), so hasattr() alone does not guarantee a shape.
        weight_shape = getattr(getattr(module, 'weight', None), 'shape', None)
        if weight_shape is not None:
            print(f"  weight shape: {weight_shape}")
        if isinstance(module, nn.ModuleList):
            for i, expert in enumerate(module):
                print(f"  Expert {i}: {expert}")


def _prepare_episode(support_set, support_labels, query_set, query_labels,
                     device, n_way, k_shot, model_name):
    """Move one episode's tensors to ``device`` and regroup them for BDC.

    For ``model_name == 'BDC'`` the support set is viewed as
    ``(batch, n_way, k_shot, ...)`` and the query set as
    ``(batch, n_way, query.shape[1] // n_way, ...)``. Other models get the
    tensors unchanged apart from the device move.
    """
    support_set = support_set.to(device)
    support_labels = support_labels.to(device)
    query_set = query_set.to(device)
    query_labels = query_labels.to(device)
    if model_name == 'BDC':
        # NOTE(review): a plain view is only meaningful if dim 1 is
        # class-major (all samples of class 0, then class 1, ...) - confirm
        # against the data loader.
        support_set = support_set.view(
            support_set.shape[0], n_way, k_shot, *support_set.shape[2:]
        )
        query_set = query_set.view(
            query_set.shape[0],
            n_way,
            query_set.shape[1] // n_way,
            *query_set.shape[2:],
        )
    return support_set, support_labels, query_set, query_labels


def _episode_loss_and_acc(model, criterion, episode, n_way, lambda1, lambda2):
    """Run ``model`` on one prepared episode and score it.

    Returns:
        ``(total_loss, acc)``. ``total_loss`` is ``criterion`` on the
        flattened query logits plus ``lambda1 * kl_b + lambda2 * kl_g``; it
        stays attached to the autograd graph when gradients are enabled.
        ``acc`` is the fraction of queries whose arg-max logit equals the
        label.
    """
    support_set, support_labels, query_set, query_labels = episode
    # The model returns (logits, kl_b, kl_g).
    logits, kl_b, kl_g = model(support_set, query_set, support_labels)
    # Flatten to (num_queries, n_way) / (num_queries,) for CrossEntropyLoss.
    logits = logits.reshape(-1, n_way)
    targets = query_labels.reshape(-1)
    total_loss = criterion(logits, targets) + lambda1 * kl_b + lambda2 * kl_g
    _, predicted = torch.max(logits, 1)
    acc = (predicted == targets).sum().item() / targets.size(0)
    return total_loss, acc


def train_model(train_loader, val_loader, model, n_way, k_shot, n_query, lr = 1e-5, epochs=100, kl_weight=1.0, args = None):
    """Train ``model`` episodically with Adamax, validating after every epoch.

    Each epoch makes one pass over ``train_loader``, optimising
    ``CE + args['lambda1'] * kl_b + args['lambda2'] * kl_g``. It then scores
    ``val_loader`` without gradients. Whenever mean validation accuracy
    improves, the ``state_dict`` is written to
    ``lr_{lr}_{n_way}way_{k_shot}shot_{experts}experts_{classifier}_best_model.pth``
    in the current working directory.

    Args:
        train_loader: iterable of
            ``(support_set, support_labels, query_set, query_labels)``.
        val_loader: iterable with the same element layout.
        model: module called as ``model(support_set, query_set, support_labels)``
            and returning ``(logits, kl_b, kl_g)``.
        n_way: classes per episode.
        k_shot: support samples per class.
        n_query: unused; kept for call compatibility.
        lr: Adamax learning rate.
        epochs: number of training epochs.
        kl_weight: unused (the KL terms are weighted by ``args['lambda1']``
            and ``args['lambda2']``); kept for call compatibility.
        args: required dict with keys ``'model'``, ``'lambda1'``,
            ``'lambda2'``, ``'experts'`` and ``'classifier'``.

    Returns:
        The model after the last epoch (on the selected device). This is not
        necessarily the best checkpoint.

    Raises:
        ValueError: if ``args`` is None or a loader yields no episodes.
    """
    if args is None:
        raise ValueError(
            "train_model requires an `args` dict "
            "(keys: model, lambda1, lambda2, experts, classifier)"
        )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model_name = args['model']
    lambda1 = args['lambda1']
    lambda2 = args['lambda2']

    optimizer = Adamax(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    best_val_acc = 0.0
    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        train_acc = 0.0
        train_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for support_set, support_labels, query_set, query_labels in pbar:
            episode = _prepare_episode(
                support_set, support_labels, query_set, query_labels,
                device, n_way, k_shot, model_name,
            )
            total_loss, acc = _episode_loss_and_acc(
                model, criterion, episode, n_way, lambda1, lambda2
            )

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            train_loss += total_loss.item()
            train_acc += acc
            train_batches += 1

            pbar.set_postfix({
                'train_loss': train_loss / train_batches,
                'train_acc': train_acc / train_batches,
            })
        if train_batches == 0:
            raise ValueError("train_loader yielded no episodes")

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        val_acc = 0.0
        val_batches = 0
        with torch.no_grad():
            for support_set, support_labels, query_set, query_labels in val_loader:
                episode = _prepare_episode(
                    support_set, support_labels, query_set, query_labels,
                    device, n_way, k_shot, model_name,
                )
                total_loss, acc = _episode_loss_and_acc(
                    model, criterion, episode, n_way, lambda1, lambda2
                )
                val_loss += total_loss.item()
                val_acc += acc
                val_batches += 1
        if val_batches == 0:
            raise ValueError("val_loader yielded no episodes")

        mean_val_acc = val_acc / val_batches
        print(f"\nEpoch {epoch+1}:")
        print(f"Train Loss: {train_loss/train_batches:.4f} | Train Acc: {train_acc/train_batches:.4f}")
        print(f"Val Loss: {val_loss/val_batches:.4f} | Val Acc: {mean_val_acc:.4f}")

        # Checkpoint only on strict improvement in mean validation accuracy.
        if mean_val_acc > best_val_acc:
            best_val_acc = mean_val_acc
            exps = args['experts']
            cla = args['classifier']
            torch.save(model.state_dict(), f'lr_{lr}_{n_way}way_{k_shot}shot_{exps}experts_{cla}_best_model.pth')
            print("Saved new best model")

    return model

def evaluate(model, test_loader, n_way, k_shot, n_query, args = None):
    """Measure per-episode query accuracy of ``model`` on ``test_loader``.

    Runs without gradients in eval mode on the device of the model's first
    parameter. Prints the mean and the (population, ``np.std``) standard
    deviation of the per-episode accuracies.

    Args:
        model: module called as ``model(support_set, query_set, support_labels)``
            and returning ``(logits, kl_b, kl_g)``. The KL terms are ignored
            here.
        test_loader: iterable of
            ``(support_set, support_labels, query_set, query_labels)``.
        n_way: classes per episode.
        k_shot: support samples per class (used for the BDC reshape).
        n_query: unused; kept for call compatibility.
        args: required dict; ``args['model'] == 'BDC'`` selects the per-class
            reshape of support and query sets.

    Returns:
        Mean per-episode accuracy as a float.

    Raises:
        ValueError: if ``args`` is None or ``test_loader`` yields no episodes.
    """
    if args is None:
        raise ValueError("evaluate requires an `args` dict (key: model)")
    device = next(model.parameters()).device
    model.eval()
    acc_list = []

    with torch.no_grad():
        for support_set, support_labels, query_set, query_labels in test_loader:
            support_set = support_set.to(device)
            support_labels = support_labels.to(device)
            query_set = query_set.to(device)
            query_labels = query_labels.to(device)
            if args['model'] == 'BDC':
                # BDC consumes (batch, n_way, samples_per_class, ...).
                support_set = support_set.view(
                    support_set.shape[0], n_way, k_shot, *support_set.shape[2:]
                )
                query_set = query_set.view(
                    query_set.shape[0],
                    n_way,
                    query_set.shape[1] // n_way,
                    *query_set.shape[2:],
                )

            # The KL terms only feed the training loss; ignore them here.
            logits, _, _ = model(support_set, query_set, support_labels)
            logits = logits.reshape(-1, n_way)
            targets = query_labels.reshape(-1)
            _, predicted = torch.max(logits, 1)
            acc_list.append((predicted == targets).sum().item() / targets.size(0))

    if not acc_list:
        raise ValueError("test_loader yielded no episodes; cannot compute accuracy")
    mean_acc = sum(acc_list) / len(acc_list)
    std_acc = np.std(acc_list)
    print(f"\nTest Accuracy: {mean_acc:.4f} ± {std_acc:.4f}")
    return mean_acc